{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning using LLM Engine"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fine-tuning helps improve model performance by training on specific examples of prompts and desired responses. LLMs are initially trained on data collected from the entire internet. With fine-tuning, LLMs can be optimized to perform better in a specific domain by learning from examples for that domain. Smaller LLMs that have been fine-tuned on a specific use case often outperform larger ones that were trained more generally.\n",
    "\n",
    "In this notebook, we will demonstrate fine-tuning open source models in order to classify emails into two categories, based on their content."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will demonstrate fine-tuning open source models in order to classify emails into two categories, based on their content.\n",
    "\n",
    "We will prepare 950 examples to fine-tune on, and use 50 examples to test the performance of our fine-tune. \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups\n",
    "import pandas as pd\n",
    "\n",
    "categories = ['rec.sport.baseball', 'rec.sport.hockey']\n",
    "sports_dataset = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories=categories)\n",
    "\n",
    "labels = [sports_dataset.target_names[x].split('.')[-1] for x in sports_dataset['target']]\n",
    "texts = [text.strip() for text in sports_dataset['data']]\n",
    "df = pd.DataFrame(zip(texts, labels), columns = ['raw_prompt','response'])[:1000]\n",
    "df_train = df[:950]\n",
    "df_test = df[950:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first take a look at our data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"From: dougb@comm.mot.com (Doug Bank)\\nSubject: Re: Info needed for Cleveland tickets\\nReply-To: dougb@ecs.comm.mot.com\\nOrganization: Motorola Land Mobile Products Sector\\nDistribution: usa\\nNntp-Posting-Host: 145.1.146.35\\nLines: 17\\n\\nIn article <1993Apr1.234031.4950@leland.Stanford.EDU>, bohnert@leland.Stanford.EDU (matthew bohnert) writes:\\n\\n|> I'm going to be in Cleveland Thursday, April 15 to Sunday, April 18.\\n|> Does anybody know if the Tribe will be in town on those dates, and\\n|> if so, who're they playing and if tickets are available?\\n\\nThe tribe will be in town from April 16 to the 19th.\\nThere are ALWAYS tickets available! (Though they are playing Toronto,\\nand many Toronto fans make the trip to Cleveland as it is easier to\\nget tickets in Cleveland than in Toronto.  Either way, I seriously\\ndoubt they will sell out until the end of the season.)\\n\\n-- \\nDoug Bank                       Private Systems Division\\ndougb@ecs.comm.mot.com          Motorola Communications Sector\\ndougb@nwu.edu                   Schaumburg, Illinois\\ndougb@casbah.acns.nwu.edu       708-576-8207\""
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train['raw_prompt'].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "response\n",
       "baseball    498\n",
       "hockey      452\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train['response'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "response\n",
       "baseball    25\n",
       "hockey      25\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test['response'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are training a text generation model, let's do a bit of (extremely basic) prompt engineering to use the model for classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt(text: str):\n",
    "    return f\"Prompt: {text}\\nCategory: \"\n",
    "\n",
    "def prepare_df(df: pd.DataFrame):\n",
    "    # df['prompt'] = df.apply(lambda row: build_prompt(row['raw_prompt']), axis=1)\n",
    "    df['prompt'] = df['raw_prompt'].apply(build_prompt)\n",
    "    df.drop('raw_prompt', axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prepare_df(df_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>response</th>\n",
       "      <th>prompt</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>baseball</td>\n",
       "      <td>Prompt: From: dougb@comm.mot.com (Doug Bank)\\n...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>hockey</td>\n",
       "      <td>Prompt: From: gld@cunixb.cc.columbia.edu (Gary...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>baseball</td>\n",
       "      <td>Prompt: From: rudy@netcom.com (Rudy Wade)\\nSub...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>hockey</td>\n",
       "      <td>Prompt: From: monack@helium.gas.uug.arizona.ed...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>baseball</td>\n",
       "      <td>Prompt: Subject: Let it be Known\\nFrom: &lt;ISSBT...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   response                                             prompt\n",
       "0  baseball  Prompt: From: dougb@comm.mot.com (Doug Bank)\\n...\n",
       "1    hockey  Prompt: From: gld@cunixb.cc.columbia.edu (Gary...\n",
       "2  baseball  Prompt: From: rudy@netcom.com (Rudy Wade)\\nSub...\n",
       "3    hockey  Prompt: From: monack@helium.gas.uug.arizona.ed...\n",
       "4  baseball  Prompt: Subject: Let it be Known\\nFrom: <ISSBT..."
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data needs to end up in a CSV file that has two columns: `prompt` and `response`, and that is publicly accessible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train.to_csv(\"sports_training_dataset.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Currently, data needs to be uploaded to a publicly accessible web URL so that it can be read for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. Support for privately sharing data with the LLM Engine API is coming shortly. For quick iteration, you can look into tools like Pastebin or Github Gists to quickly host your CSV files in a public manner. We created an example Github Gist you can see [here](https://gist.github.com/tigss/7cec73251a37de72756a3b15eace9965). To use the gist, you can just use the URL given when you click the “Raw” button ([URL](https://gist.githubusercontent.com/tigss/7cec73251a37de72756a3b15eace9965/raw/85d9742890e1e6b0c06468507292893b820c13c9/llm_sample_data.csv)).\n",
    "\n",
    "We've uploaded our CSV file to `s3://scale-demo-datasets/sports/sports_training_dataset.csv`, which maps to a URL of `https://scale-demo-datasets.s3.us-west-2.amazonaws.com/sports/sports_training_dataset.csv`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-Tuning the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we create the fine-tune from our training file via the FineTune API. Note: this can take roughly 15-20 minutes from launching the job to finishing the job with a few hundred examples, as there is a queue of jobs to run.\n",
    "\n",
    "For this section, you will need an API key to interact with Scale. To retrieve your API key, head to [Scale Spellbook](https://spellbook.scale.com/) where you will get an API key on the [settings](https://spellbook.scale.com/settings) page."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: you must have the environment variable SCALE_API_KEY set to your Spellbook API key. \n",
    "\n",
    "from llmengine import FineTune, Completion, Model\n",
    "\n",
    "FineTune.validate_api_key()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_fine_tune_response = FineTune.create(\n",
    "    model=\"llama-2-7b\",\n",
    "    training_file=\"https://scale-demo-datasets.s3.us-west-2.amazonaws.com/sports/sports_training_dataset.csv\",\n",
    "    validation_file=None,\n",
    "    hyperparameters={\"epochs\": \"1\", \"lr\": \"0.0002\"},\n",
    "    suffix=\"my-first-fine-tune\"\n",
    ")\n",
    "\n",
    "fine_tune_id = create_fine_tune_response.fine_tune_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BatchJobStatus.RUNNING\n"
     ]
    }
   ],
   "source": [
    "# Wait for fine tune to complete\n",
    "\n",
    "\n",
    "fine_tune_status = FineTune.get(fine_tune_id).status\n",
    "print(fine_tune_status)\n",
    "if fine_tune_status == \"SUCCESS\":\n",
    "    print(\"Fine-Tune Succeeded!\")\n",
    "elif fine_tune_status in [\"FAILURE\", \"CANCELLED\"]:\n",
    "    raise ValueError(\"Fine-Tune failed\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the fine-tune completes, we can get your fine-tune via looking at all the models available to you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can look at all the models that are available for completions\n",
    "all_models = Model.list().model_endpoints\n",
    "all_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "your_fine_tuned_model = \"llama-2-7b.my-first-fine-tune.2023-07-19-00-48-07\"  # Note: you will have a different model!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the Fine-Tune"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we run our model on the test dataset via the Completions API. Since we trained using a prompt template, use that prompt template when making predictions. Note: you may have to wait a few minutes after the fine-tune succeeds in order for your model to be loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_classification(prompt: str):\n",
    "    for _ in range(5):\n",
    "        try:\n",
    "            response = Completion.create(\n",
    "                model=your_fine_tuned_model, \n",
    "                prompt=build_prompt(prompt), \n",
    "                max_new_tokens=2, \n",
    "                temperature=0.01\n",
    "            )\n",
    "            return response.output.text.rstrip(\"\\n\")\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "    else:\n",
    "        return \"Error\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test[\"predicted_response\"] = df_test[\"raw_prompt\"].apply(get_classification)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's peek at the data and calculate our test accuracy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>raw_prompt</th>\n",
       "      <th>response</th>\n",
       "      <th>predicted_response</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>950</th>\n",
       "      <td>From: tedward@cs.cornell.edu (Edward [Ted] Fis...</td>\n",
       "      <td>baseball</td>\n",
       "      <td>baseball</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>951</th>\n",
       "      <td>From: smorris@venus.lerc.nasa.gov (Ron Morris ...</td>\n",
       "      <td>hockey</td>\n",
       "      <td>hockey</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>952</th>\n",
       "      <td>From: shah@pitt.edu (Ravindra S Shah)\\nSubject...</td>\n",
       "      <td>hockey</td>\n",
       "      <td>hockey</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>953</th>\n",
       "      <td>From: timlin@spot.Colorado.EDU (Michael Timlin...</td>\n",
       "      <td>baseball</td>\n",
       "      <td>baseball</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>954</th>\n",
       "      <td>From: gp2011@andy.bgsu.edu (George Pavlic)\\nSu...</td>\n",
       "      <td>hockey</td>\n",
       "      <td>hockey</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            raw_prompt  response  \\\n",
       "950  From: tedward@cs.cornell.edu (Edward [Ted] Fis...  baseball   \n",
       "951  From: smorris@venus.lerc.nasa.gov (Ron Morris ...    hockey   \n",
       "952  From: shah@pitt.edu (Ravindra S Shah)\\nSubject...    hockey   \n",
       "953  From: timlin@spot.Colorado.EDU (Michael Timlin...  baseball   \n",
       "954  From: gp2011@andy.bgsu.edu (George Pavlic)\\nSu...    hockey   \n",
       "\n",
       "    predicted_response  \n",
       "950           baseball  \n",
       "951             hockey  \n",
       "952             hockey  \n",
       "953           baseball  \n",
       "954             hockey  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.98"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_correct = len(df_test[df_test[\"predicted_response\"] == (df_test[\"response\"])])\n",
    "num_correct / len(df_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>raw_prompt</th>\n",
       "      <th>response</th>\n",
       "      <th>predicted_response</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>974</th>\n",
       "      <td>From: maX &lt;maX@maxim.rinaco.msk.su&gt;\\nSubject: ...</td>\n",
       "      <td>hockey</td>\n",
       "      <td>baseball</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            raw_prompt response  \\\n",
       "974  From: maX <maX@maxim.rinaco.msk.su>\\nSubject: ...   hockey   \n",
       "\n",
       "    predicted_response  \n",
       "974           baseball  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test[df_test[\"predicted_response\"] != df_test[\"response\"]]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
